[perf] Improve performance for putting jagged tensor #36
0oshowero0 merged 8 commits into Ascend:main
Conversation
Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
CLA Signature Pass: 0oshowero0, thanks for your pull request. All authors of the commits have signed the CLA. 👍
Pull request overview
This PR fixes a performance bottleneck in `put_data` when processing TensorDict fields backed by jagged (nested) tensors, by avoiding repeated expensive multi-indexing on jagged tensors.
Changes:
- Optimize `_filter_storage_data` to unbind jagged tensors before applying `itemgetter` over multiple batch indexes.
- Add a note in `KVStorageManager._generate_values` indicating a similar potential optimization for jagged tensors.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| transfer_queue/storage/managers/simple_backend_manager.py | Adds a jagged-tensor fast path in `_filter_storage_data` by unbinding before multi-index selection. |
| transfer_queue/storage/managers/base.py | Adds a TODO note in `_generate_values` related to jagged tensor handling. |
```python
# unbind jagged tensor
results: dict = {}
for field in sorted(data.keys()):
    field_data = data[field]

    # For jagged tensors, unbind() first to accelerate indexing process
    if isinstance(field_data, Tensor) and field_data.layout == torch.jagged:
        results[field] = field_data.unbind()
    else:
        results[field] = field_data
```
This change introduces a jagged-tensor fast path (pre-unbind before indexing), but there’s no test exercising put_data with layout=torch.jagged. Adding a unit test that uses a jagged tensor field and asserts the data sent to _put_to_single_storage_unit matches expected samples would prevent regressions (and ensure the performance fix stays wired in).
This adds a new jagged-tensor fast path (unbind() before indexing), but there isn't a unit test exercising it. Consider extending the existing tests/test_async_simple_storage_manager.py::test_async_storage_manager_mock_operations to include a layout=torch.jagged nested tensor and assert unbind() is called and that _put_to_single_storage_unit receives the expected sliced items.
Force-pushed from e3dc052 to 224719e.
Background
When users input a TensorDict containing jagged tensors (nested tensors), the `put_data` process becomes extremely slow.

Specifically, the `_filter_storage_data` function uses `itemgetter(*batch_indexes)(data[fname])` to extract individual items from each tensor in the TensorDict. This indexing approach works efficiently for strided tensors but is extremely inefficient for jagged tensors.

Root Cause

For jagged tensors, `itemgetter` with multiple batch indexes performs repeated indexing operations, each of which is $\mathcal{O}(n)$. When extracting multiple samples, this becomes $\mathcal{O}(n^2)$ overall.
Solution
We unbind the nested tensor once, and then access each sample from the resulting tuple of views.
Simple Reproduction Script
Output: